import numpy as np
import tensorflow as tf

# Command-line flags configuring this distributed TF1 training script.
flags = tf.app.flags

# Role of this process: 'worker' or 'ps' (parameter server).
flags.DEFINE_string('job_name', None, 'job name: worker or ps')
# 0-based index of this task within its job.
flags.DEFINE_integer('task_index', None, 'Index of task within the job')

# host:port addresses of the parameter-server and worker tasks.
flags.DEFINE_string('ps_hosts', 'localhost:1681', 'Comma-separated list of hostname:port pairs')
flags.DEFINE_string('worker_hosts', 'localhost:1682,localhost:1683', 'Comma-separated list of hostname:port pairs')
# Directory where checkpoints/logs are meant to be saved.
flags.DEFINE_string('log_dir', 'log/super/', 'directory path')

# Training hyper-parameters.
flags.DEFINE_integer('training_epochs', 20, 'training epochs')

FLAGS = flags.FLAGS

# Synthetic training data: y = 2x corrupted with Gaussian noise.
train_X = np.linspace(-1, 1, 100)
noise = np.random.randn(*train_X.shape) * 0.3
train_Y = 2 * train_X + noise

# Start from a clean default graph before building the distributed model.
tf.reset_default_graph()

# Parse the cluster membership from the flags.
ps_hosts = FLAGS.ps_hosts.split(',')
worker_hosts = FLAGS.worker_hosts.split(',')
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
# Create the in-process server for this task. Reuse cluster_spec instead of
# re-building the same cluster dict inline, so the two definitions cannot drift.
server = tf.train.Server(cluster_spec,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

# A parameter-server task only hosts variables: block here forever.
if FLAGS.job_name == 'ps':
    print("waiting...")
    server.join()

# Build the model graph. replica_device_setter places variables on the 'ps'
# job while keeping compute ops on this worker task.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster_spec)):
    # Scalar placeholders fed one (x, y) sample at a time during training.
    X = tf.placeholder("float")
    Y = tf.placeholder("float")
    # Model parameters (linear regression: z = W*x + b).
    W = tf.Variable(tf.random_normal([1]), name="weight")
    b = tf.Variable(tf.zeros([1]), name="bias")

    # Shared step counter. tf.train.get_or_create_global_step is the
    # supported spelling; the tf.contrib.framework alias is deprecated
    # but has identical semantics.
    global_step = tf.train.get_or_create_global_step()

    # Forward pass.
    z = tf.multiply(X, W) + b
    # Loss: mean squared error.
    cost = tf.reduce_mean(tf.square(Y - z))
    learning_rate = 0.01
    # Gradient descent; each update increments global_step.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step)

    init = tf.global_variables_initializer()


# The chief (task 0) runs init_op and writes checkpoints/summaries to
# FLAGS.log_dir; non-chief workers wait for it. The log_dir flag was
# previously defined but never used — wiring it in enables checkpointing
# and automatic recovery on restart.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                         logdir=FLAGS.log_dir,
                         init_op=init,
                         global_step=global_step)

# Connect to this task's server and create a managed session.
with sv.managed_session(server.target) as sess:
    print(global_step.eval(session=sess))

    # Train until the workers have collectively taken
    # training_epochs * len(train_X) optimization steps. The original
    # outer `for epoch in range(...)` ran that many *full passes* over the
    # data (reassigning `epoch` inside the body cannot shorten a range
    # fixed at creation), i.e. len(train_X)-times too many updates.
    max_steps = FLAGS.training_epochs * len(train_X)
    while global_step.eval(session=sess) < max_steps:
        for (x, y) in zip(train_X, train_Y):
            _, step = sess.run([optimizer, global_step], feed_dict={X: x, Y: y})

            # Full-dataset loss after each update, for progress logging.
            loss = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", step + 1, "cost=", loss, "W=", sess.run(W), "b=", sess.run(b))

    print(" Finished!")
sv.stop()